{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "l7Sh70ascFNs"
   },
   "source": [
    "# Solving the wave equation on cloud TPUs\n",
    "\n",
    "[_Stephan Hoyer_](https://twitter.com/shoyer)\n",
    "\n",
    "In this notebook, we solve the 2D [wave equation](https://en.wikipedia.org/wiki/Wave_equation):\n",
    "$$\n",
    "\\frac{\\partial^2 u}{\\partial t^2} = c^2 \\nabla^2 u\n",
    "$$\n",
    "\n",
    "We use a simple [finite difference](https://en.wikipedia.org/wiki/Finite_difference_method) formulation with [Leapfrog time integration](https://en.wikipedia.org/wiki/Leapfrog_integration).\n",
    "\n",
    "Note: It is natural to express finite difference methods as convolutions, but here we intentionally avoid convolutions in favor of array indexing/arithmetic. This is because \"batch\" and \"feature\" dimensions in TPU convolutions are padded to multiples of either 8 and 128, but in our case both these dimensions are effectively of size 1.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "xAd8PvW6ceyk"
   },
   "source": [
    "## Setup required environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "n6L8zIaAIrqj"
   },
   "outputs": [],
   "source": [
    "# Grab other packages for this demo.\n",
    "!pip install -U -q Pillow moviepy proglog scikit-image"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UHkzcAMhcpm_"
   },
   "source": [
    "## Simulation code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "4NJQ8Q5M99wy"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "import jax\n",
    "from jax import lax\n",
    "from jax import tree_util\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import skimage.filters\n",
    "import proglog\n",
    "from moviepy.editor import ImageSequenceClip\n",
    "\n",
    "device_count = jax.device_count()\n",
    "\n",
    "# Spatial partitioning via halo exchange\n",
    "\n",
    "def send_right(x, axis_name):\n",
    "  # Note: if some devices are omitted from the permutation, lax.ppermute\n",
    "  # provides zeros instead. This gives us an easy way to apply Dirichlet\n",
    "  # boundary conditions.\n",
    "  left_perm = [(i, (i + 1) % device_count) for i in range(device_count - 1)]\n",
    "  return lax.ppermute(x, perm=left_perm, axis_name=axis_name)\n",
    "\n",
    "def send_left(x, axis_name):\n",
    "  left_perm = [((i + 1) % device_count, i) for i in range(device_count - 1)]\n",
    "  return lax.ppermute(x, perm=left_perm, axis_name=axis_name)\n",
    "\n",
    "def axis_slice(ndim, index, axis):\n",
    "  slices = [slice(None)] * ndim\n",
    "  slices[axis] = index\n",
    "  return tuple(slices)\n",
    "\n",
    "def slice_along_axis(array, index, axis):\n",
    "  return array[axis_slice(array.ndim, index, axis)]\n",
    "\n",
    "def tree_vectorize(func):\n",
    "  def wrapper(x, *args, **kwargs):\n",
    "    return tree_util.tree_map(lambda x: func(x, *args, **kwargs), x)\n",
    "  return wrapper\n",
    "\n",
    "@tree_vectorize\n",
    "def halo_exchange_padding(array, padding=1, axis=0, axis_name='x'):\n",
    "  if not padding > 0:\n",
    "    raise ValueError(f'invalid padding: {padding}')\n",
    "  array = jnp.array(array)\n",
    "  if array.ndim == 0:\n",
    "    return array\n",
    "  left = slice_along_axis(array, slice(None, padding), axis)\n",
    "  right = slice_along_axis(array, slice(-padding, None), axis)\n",
    "  right, left = send_left(left, axis_name), send_right(right, axis_name)\n",
    "  return jnp.concatenate([left, array, right], axis)\n",
    "\n",
    "@tree_vectorize\n",
    "def halo_exchange_inplace(array, padding=1, axis=0, axis_name='x'):\n",
    "  left = slice_along_axis(array, slice(padding, 2*padding), axis)\n",
    "  right = slice_along_axis(array, slice(-2*padding, -padding), axis)\n",
    "  right, left = send_left(left, axis_name), send_right(right, axis_name)\n",
    "  array = array.at[axis_slice(array.ndim, slice(None, padding), axis)].set(left)\n",
    "  array = array.at[axis_slice(array.ndim, slice(-padding, None), axis)].set(right)\n",
    "  return array\n",
    "\n",
    "# Reshaping inputs/outputs for pmap\n",
    "\n",
    "def split_with_reshape(array, num_splits, *, split_axis=0, tile_id_axis=None):\n",
    "  if tile_id_axis is None:\n",
    "    tile_id_axis = split_axis\n",
    "  tile_size, remainder = divmod(array.shape[split_axis], num_splits)\n",
    "  if remainder:\n",
    "    raise ValueError('num_splits must equally divide the dimension size')\n",
    "  new_shape = list(array.shape)\n",
    "  new_shape[split_axis] = tile_size\n",
    "  new_shape.insert(split_axis, num_splits)\n",
    "  return jnp.moveaxis(jnp.reshape(array, new_shape), split_axis, tile_id_axis)\n",
    "\n",
    "def stack_with_reshape(array, *, split_axis=0, tile_id_axis=None):\n",
    "  if tile_id_axis is None:\n",
    "    tile_id_axis = split_axis\n",
    "  array = jnp.moveaxis(array, tile_id_axis, split_axis)\n",
    "  new_shape = array.shape[:split_axis] + (-1,) + array.shape[split_axis+2:]\n",
    "  return jnp.reshape(array, new_shape)\n",
    "\n",
    "def shard(func):\n",
    "  def wrapper(state):\n",
    "    sharded_state = tree_util.tree_map(\n",
    "        lambda x: split_with_reshape(x, device_count), state)\n",
    "    sharded_result = func(sharded_state)\n",
    "    result = tree_util.tree_map(stack_with_reshape, sharded_result)\n",
    "    return result\n",
    "  return wrapper\n",
    "\n",
    "# Physics\n",
    "\n",
    "def shift(array, offset, axis):\n",
    "  index = slice(offset, None) if offset >= 0 else slice(None, offset)\n",
    "  sliced = slice_along_axis(array, index, axis)\n",
    "  padding = [(0, 0)] * array.ndim\n",
    "  padding[axis] = (-min(offset, 0), max(offset, 0))\n",
    "  return jnp.pad(sliced, padding, mode='constant', constant_values=0)\n",
    "\n",
    "def laplacian(array, step=1):\n",
    "  left = shift(array, +1, axis=0)\n",
    "  right = shift(array, -1, axis=0)\n",
    "  up = shift(array, +1, axis=1)\n",
    "  down = shift(array, -1, axis=1)\n",
    "  convolved = (left + right + up + down - 4 * array)\n",
    "  if step != 1:\n",
    "    convolved *= (1 / step ** 2)\n",
    "  return convolved\n",
    "\n",
    "def scalar_wave_equation(u, c=1, dx=1):\n",
    "  return c ** 2 * laplacian(u, dx)\n",
    "\n",
    "@jax.jit\n",
    "def leapfrog_step(state, dt=0.5, c=1):\n",
    "  # https://en.wikipedia.org/wiki/Leapfrog_integration\n",
    "  u, u_t = state\n",
    "  u_tt = scalar_wave_equation(u, c)\n",
    "  u_t = u_t + u_tt * dt\n",
    "  u = u + u_t * dt\n",
    "  return (u, u_t)\n",
    "\n",
    "# Time stepping\n",
    "\n",
    "def multi_step(state, count, dt=1/jnp.sqrt(2), c=1):\n",
    "  return lax.fori_loop(0, count, lambda i, s: leapfrog_step(s, dt, c), state)\n",
    "\n",
    "def multi_step_pmap(state, count, dt=1/jnp.sqrt(2), c=1, exchange_interval=1,\n",
    "                    save_interval=1):\n",
    "\n",
    "  def exchange_and_multi_step(state_padded):\n",
    "    c_padded = halo_exchange_padding(c, exchange_interval)\n",
    "    evolved = multi_step(state_padded, exchange_interval, dt, c_padded)\n",
    "    return halo_exchange_inplace(evolved, exchange_interval)\n",
    "\n",
    "  @shard\n",
    "  @partial(jax.pmap, axis_name='x')\n",
    "  def simulate_until_output(state):\n",
    "    stop = save_interval // exchange_interval\n",
    "    state_padded = halo_exchange_padding(state, exchange_interval)\n",
    "    advanced = lax.fori_loop(\n",
    "        0, stop, lambda i, s: exchange_and_multi_step(s), state_padded)\n",
    "    xi = exchange_interval\n",
    "    return tree_util.tree_map(lambda array: array[xi:-xi, ...], advanced)\n",
    "\n",
    "  results = [state]\n",
    "  for _ in range(count // save_interval):\n",
    "    state = simulate_until_output(state)\n",
    "    tree_util.tree_map(lambda x: x.copy_to_host_async(), state)\n",
    "    results.append(state)\n",
    "  results = jax.device_get(results)\n",
    "  return tree_util.tree_map(lambda *xs: np.stack([np.array(x) for x in xs]), *results)\n",
    "\n",
    "multi_step_jit = jax.jit(multi_step)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "_kMyXbkEeCu3"
   },
   "source": [
    "## Initial conditions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "usWkf2UgAYc5"
   },
   "outputs": [],
   "source": [
    "x = jnp.linspace(0, 8, num=8*1024, endpoint=False)\n",
    "y = jnp.linspace(0, 1, num=1*1024, endpoint=False)\n",
    "x_mesh, y_mesh = jnp.meshgrid(x, y, indexing='ij')\n",
    "\n",
    "# NOTE: smooth initial conditions are important, so we aren't exciting\n",
    "# arbitrarily high frequencies (that cannot be resolved)\n",
    "u = skimage.filters.gaussian(\n",
    "    ((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) < 0.1 ** 2,\n",
    "    sigma=1)\n",
    "\n",
    "# u = jnp.exp(-((x_mesh - 1/3) ** 2 + (y_mesh - 1/4) ** 2) / 0.1 ** 2)\n",
    "\n",
    "# u = skimage.filters.gaussian(\n",
    "#     (x_mesh > 1/3) & (x_mesh < 1/2) & (y_mesh > 1/3) & (y_mesh < 1/2),\n",
    "#     sigma=5)\n",
    "\n",
    "v = jnp.zeros_like(u)\n",
    "c = 1  # could also use a 2D array matching the mesh shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jWtIvDTfCx12"
   },
   "outputs": [],
   "source": [
    "u.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RVdPIRKmeNX3"
   },
   "source": [
    "## Test scaling from 1 to 8 chips"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "SPG6jvYSCaKQ"
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# single TPU chip\n",
    "u_final, _ = multi_step_jit((u, v), count=2**13, c=c, dt=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "MpsDzNyC6OI0"
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# 8x TPU chips, 4x more steps in roughly half the time!\n",
    "u_final, _ = multi_step_pmap(\n",
    "    (u, v), count=2**15, c=c, dt=0.5, exchange_interval=4, save_interval=2**15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "fPCwJ51FeBLR"
   },
   "outputs": [],
   "source": [
    "18.3 / (10.3 / 4)  # near linear scaling (8x would be perfect)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jYMn69z8eQ7O"
   },
   "source": [
    "## Save a bunch of outputs for a movie"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Uns4X34EPYgF"
   },
   "outputs": [],
   "source": [
    "%%time\n",
    "# save more outputs for a movie -- this is slow!\n",
    "u_final, _ = multi_step_pmap(\n",
    "    (u, v), count=2**15, c=c, dt=0.2, exchange_interval=4, save_interval=2**10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Hsjuk22Nbe9Z"
   },
   "outputs": [],
   "source": [
    "u_final.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LXqmz-CCdQjt"
   },
   "outputs": [],
   "source": [
    "u_final.nbytes / 1e9"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "cPvXrPUCPPtt"
   },
   "outputs": [],
   "source": [
    "plt.figure(figsize=(18, 6))\n",
    "plt.axis('off')\n",
    "plt.imshow(u_final[-1].T, cmap='RdBu');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uqxGmttmgSsN"
   },
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(9, 1, figsize=(14, 14))\n",
    "[ax.axis('off') for ax in axes]\n",
    "axes[0].imshow(u_final[0].T, cmap='RdBu', aspect='equal', vmin=-1, vmax=1)\n",
    "for i in range(8):\n",
    "  axes[i+1].imshow(u_final[4*i+1].T / abs(u_final[4*i+1]).max(), cmap='RdBu', aspect='equal', vmin=-1, vmax=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hNZrt590p6zr"
   },
   "outputs": [],
   "source": [
    "import matplotlib.cm\n",
    "import matplotlib.colors\n",
    "from PIL import Image\n",
    "\n",
    "def make_images(data, cmap='RdBu', vmax=None):\n",
    "  images = []\n",
    "  for frame in data:\n",
    "    if vmax is None:\n",
    "      this_vmax = np.max(abs(frame))\n",
    "    else:\n",
    "      this_vmax = vmax\n",
    "    norm = matplotlib.colors.Normalize(vmin=-this_vmax, vmax=this_vmax)\n",
    "    mappable = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap)\n",
    "    rgba = mappable.to_rgba(frame, bytes=True)\n",
    "    image = Image.fromarray(rgba, mode='RGBA')\n",
    "    images.append(image)\n",
    "  return images\n",
    "\n",
    "def save_movie(images, path, duration=100, loop=0, **kwargs):\n",
    "  images[0].save(path, save_all=True, append_images=images[1:],\n",
    "                 duration=duration, loop=loop, **kwargs)\n",
    "\n",
    "images = make_images(u_final[::, ::8, ::8].transpose(0, 2, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ZWEZWehFjboa"
   },
   "outputs": [],
   "source": [
    "# Show Movie\n",
    "proglog.default_bar_logger = partial(proglog.default_bar_logger, None)\n",
    "ImageSequenceClip([np.array(im) for im in images], fps=25).ipython_display()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "whG0OOepezuX"
   },
   "outputs": [],
   "source": [
    "# Save GIF.\n",
    "save_movie(images,'wave_movie.gif', duration=[2000]+[200]*(len(images)-2)+[2000])\n",
    "# The movie sometimes takes a second before showing up in the file system.\n",
    "import time; time.sleep(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "E1V7aS63vRAT"
   },
   "outputs": [],
   "source": [
    "# Download animation.\n",
    "try:\n",
    "    from google.colab import files\n",
    "except ImportError:\n",
    "    pass\n",
    "else:\n",
    "    files.download('wave_movie.gif')"
   ]
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "collapsed_sections": [],
   "name": "JAX TPU wave equation",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
